Conversation
Signed-off-by: giulio98 <corallo.giulio@yahoo.it>
Co-authored-by: miriam-16 <miriam-16@users.noreply.github.com>
…inch_attention), introduce context/question handling in pipeline.
… step
…ngths across iterations
…en selected
Hi @giulio98, Thanks for your contribution. I started looking at the PR. Two initial remarks:
I started to look more into `finch_press.py`. Would it be possible to re-use components from other presses? For instance you could delete … Comparison with the initial `score` method and the `score` method I propose:

```python
def score(self, module, hidden_states, keys, values, attentions, kwargs):
    bsz, num_key_value_heads, q_len, _ = keys.shape
    num_key_value_groups = module.config.num_attention_heads // num_key_value_heads
    if attentions is not None:
        attn_weights = attentions[..., -self.condition_len :, : -self.condition_len]
    else:
        attn_weights = SnapKVPress.compute_window_attention(
            module, hidden_states, keys, self.condition_len, kwargs["position_embeddings"]
        )
    if self.normalize_scores:
        non_zero_counts = torch.arange(q_len - self.condition_len, q_len)[None, None, :, None]
        non_zero_counts = non_zero_counts.to(attn_weights.device)
        attn_weights = attn_weights * non_zero_counts
    # Average per group
    scores = attn_weights.mean(dim=-2)
    scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.condition_len)
    scores = scores.mean(dim=2)
    # Add back the observation window. Use max score to make sure the window is not pruned.
    scores = F.pad(scores, (0, self.condition_len), value=scores.max().item())
    return scores
```

Note that in the code above I replaced `sum` by `mean` to avoid too large floats when using …
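To illustrate the `torch.arange` normalization discussed above (with toy sizes, not the press's real dimensions): each of the last `condition_len` query rows attends to a different number of positions under the causal mask, and the per-row rescaling factors applied before averaging are simply `torch.arange(q_len - condition_len, q_len)`:

```python
import torch

q_len, condition_len = 10, 3
# per-row rescaling factors for the last `condition_len` (causal) query rows,
# broadcastable against attention weights of shape (bsz, heads, condition_len, context)
non_zero_counts = torch.arange(q_len - condition_len, q_len)[None, None, :, None]
print(non_zero_counts.squeeze().tolist())  # [7, 8, 9]
```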
```python
def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
```

is there any update here compared to `BasePress`? If not, remove it.
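For context, this signature matches the general shape of a PyTorch forward hook registered with `with_kwargs=True` (a generic torch example, not kvpress's actual `BasePress` implementation):

```python
import torch
from torch import nn

calls = []

def forward_hook(module, args, kwargs, output):
    # runs after module.forward; the hook can inspect or replace the output
    calls.append(tuple(output.shape))
    return output

layer = nn.Linear(4, 2)
layer.register_forward_hook(forward_hook, with_kwargs=True)
layer(torch.randn(3, 4))
print(calls)  # [(3, 2)]
```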
…te_metrics.py, use mean instead of sum to avoid large floats, replace computed normalization factors with torch.arange
Co-authored-by: SimJeg <simjeg@users.noreply.github.com>
Hello @SimJeg, instead of … because it needs to take into account the past cache of the previous iteration; similarly, I had to overwrite `forward` for the following line: … For the same reason I couldn't use … Finally, … Regarding the special …

Giulio
In … `bsz, num_key_value_heads, q_len, _ = keys.shape`, so in fact …
You are right, I replaced it with `SnapKVPress.compute_window_attention`.
@giulio98 do you think it would make a big difference (in terms of accuracy) to move from … to …? I know it's different, but the code would be much easier to read (which is my main concern so far), and it may have similar performance.
My main concern about this implementation is that the chunked forward was designed in …

On the other hand, I know that doing as you propose will make it easier to apply, for example, …
Indeed, with a compression ratio of 50%, Finch could handle inputs up to 256k even if the LLM max size is 128k (and more generally …).
I suspect that RULER is a very information-dense dataset, so it is very hard to apply compression beyond a certain limit. If instead we imagine compressing external knowledge from, say, Wikipedia pages, I think KV compression makes a lot of sense.
Draft for a proposal of a simplified `FinchPress`:

```python
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import math
from dataclasses import dataclass

import torch
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import rotate_half

from kvpress.presses.base_press import BasePress
from kvpress.presses.snapkv_press import SnapKVPress


@dataclass
class FinchPress(BasePress):
    compression_ratio: float = 0.0
    split_size: int = 1
    normalize_scores: bool = True
    condition_len: int = None

    def score(self, module, hidden_states, keys, values, attentions, kwargs):
        """
        Similar to SnapKVPress except it adds a normalization step before averaging on the context window.
        """
        bsz, num_key_value_heads, q_len, _ = keys.shape
        num_key_value_groups = module.config.num_attention_heads // num_key_value_heads
        if attentions is not None:
            attn_weights = attentions[..., -self.condition_len :, : -self.condition_len]
        else:
            attn_weights = SnapKVPress.compute_window_attention(
                module, hidden_states, keys, self.condition_len, kwargs["position_embeddings"]
            )
        if self.normalize_scores:
            non_zero_counts = torch.arange(q_len - self.condition_len, q_len)[None, None, :, None]
            non_zero_counts = non_zero_counts.to(attn_weights.device)
            attn_weights = attn_weights * non_zero_counts
        # Average per group
        scores = attn_weights.mean(dim=-2)
        scores = scores.view(bsz, num_key_value_heads, num_key_value_groups, q_len - self.condition_len)
        scores = scores.mean(dim=2)
        # Add back the observation window. Use max score to make sure the window is not pruned.
        scores = F.pad(scores, (0, self.condition_len), value=scores.max().item())
        return scores

    def compress(self, module, hidden_states, keys, values, attentions, kwargs):
        """
        Scores are computed by chunks, keys and values are then compressed and re-rotated.
        """
        q_len = hidden_states.shape[1]
        if self.compression_ratio == 0:
            return keys, values
        assert (self.condition_len is not None) and (self.condition_len < q_len)

        # Compute scores
        scores = self.score(module, hidden_states, keys, values, attentions, kwargs)

        # Compute indices by chunks
        indices = []
        chunk_size = math.ceil(q_len / self.split_size)
        for i, chunk_scores in enumerate(torch.split(scores, chunk_size, dim=2)):
            n_kept = max(1, int(chunk_scores.shape[2] * (1 - self.compression_ratio)))
            chunk_indices = i * chunk_size + chunk_scores.topk(n_kept, dim=-1).indices
            indices.append(chunk_indices)
        indices = torch.cat(indices, dim=-1)
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

        # Rerotate keys and values
        cos, sin = kwargs["position_embeddings"]
        keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * (-sin.unsqueeze(1)))
        keys = keys.gather(2, indices).contiguous()
        cos, sin = cos[:, : indices.shape[2]], sin[:, : indices.shape[2]]
        keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
        values = values.gather(2, indices).contiguous()
        return keys, values
```

I will run it for 50% compression to compare with what you reported.
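The de-rotate / gather / re-rotate step above can be checked numerically on a toy tensor. This sketch defines its own `rotate_half` (matching the standard RoPE helper) instead of importing from transformers; the shapes and angles are made up for illustration:

```python
import torch

def rotate_half(x):
    # standard RoPE helper: (x1, x2) -> (-x2, x1) on the two halves of the last dim
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin):
    return x * cos + rotate_half(x) * sin

k = torch.randn(1, 1, 1, 4)                 # toy key (bsz, heads, seq, head_dim)
cos_o, sin_o = torch.tensor(0.7).cos(), torch.tensor(0.7).sin()  # old position angle
cos_n, sin_n = torch.tensor(0.3).cos(), torch.tensor(0.3).sin()  # new (compacted) position angle

k_rot = apply_rope(k, cos_o, sin_o)         # key as stored after the original RoPE
k_plain = apply_rope(k_rot, cos_o, -sin_o)  # de-rotate: rotation by -theta undoes RoPE
k_new = apply_rope(k_plain, cos_n, sin_n)   # re-rotate at the new position

# equivalent to rotating the raw key directly at the new position
assert torch.allclose(k_new, apply_rope(k, cos_n, sin_n), atol=1e-6)
```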
@SimJeg I just noticed that you may have to sort the indices, as in line 114 of `finch_press.py`; as we saw, this will enhance performance.
@giulio98 I created a new branch here: https://github.com/NVIDIA/kvpress/tree/simon/finch What it contains: …

I will share results with 50% compression. Could you provide the detailed performance for each subtask you obtained?

Update: … So it's slightly lower than your version. One slight difference however is that I compress the question too, but that might impact only ~0.5% of the compression ratio.
Hello, I have to rerun the experiment. In the meantime, the first thing I noticed is that the sorting of the indices just after the `topk` is missing: `topk` returns indices ordered by score by default, but we may need to sort them back to their natural order, because otherwise they can take on a different meaning (this is one thing that could enhance performance in other presses too).
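The `topk` ordering point can be seen on a toy tensor (made-up scores): `topk` returns indices ranked by score, so a `sort` is needed before gathering whenever positional order carries meaning:

```python
import torch

scores = torch.tensor([[0.7, 0.1, 0.9, 0.3]])
idx = scores.topk(2, dim=-1).indices   # indices ordered by descending score
idx_sorted = idx.sort(dim=-1).values   # indices in natural (positional) order
print(idx.tolist())         # [[2, 0]]
print(idx_sorted.tolist())  # [[0, 2]]
```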
Great catch! It also impacts KeyRerotationPress so I will correct it, but it won't impact other presses as the order of keys and values does not matter. I re-ran your implementation and get 91.6. Also, I made a mistake in the numbers I reported above (I reported results for 5% of the data); with the error, I get 90.8. I will correct the error and report results for 4 options: with / without normalization / rerotation.
Ah yes! If no rerotation is applied it is permutation invariant. |
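That invariance is easy to check numerically with a toy single-query attention (made-up tensors): permuting keys and values jointly leaves the output unchanged, which is why index order only starts to matter once rerotation re-introduces position:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 1, 8)  # single query: (bsz, heads, 1, head_dim)
k = torch.randn(1, 1, 5, 8)
v = torch.randn(1, 1, 5, 8)
perm = torch.randperm(5)

out = F.scaled_dot_product_attention(q, k, v)
out_perm = F.scaled_dot_product_attention(q, k[:, :, perm], v[:, :, perm])
assert torch.allclose(out, out_perm, atol=1e-5)  # same output under joint permutation
```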
Updated results. Fixing the bug led to slightly worse performance (especially for …). Will look again at what might be the difference.
```python
    n_kept_context = int(context_length * (1 - self.compression_ratio))
else:
    past_cache_len = scores.shape[-1] - q_len
    n_kept_context = int((q_len - self.condition_len) * (1 - self.compression_ratio)) + past_cache_len
```
not correct, for instance for `last_iteration` it should be:

```python
n_kept_context = int(context_length * (1 - self.compression_ratio) - self.condition_len * self.compression_ratio)
```
The condition len shouldn't be compressed, should it? Because in other presses the question is provided as-is, so if we also compress the question the comparison will not be fair.

You are using the question in the input so it should be compressed too, as it's done in other presses. Finch can use the information of the window size (which is a bit unfair) to exclude the question from the compression, but that should translate to a slightly higher compression for the context.
Note that the difference is very small (~0.5% in CR), so I don't think it should impact performance much.
> You are using the question in the input so it should be compressed too, as it's done in other presses. Finch can use the information of the window size (which is a bit unfair) to exclude the question from the compression, but it should translate to a slightly higher compression for the context.
Query-aware compression like Finch or SnapKV and question-agnostic compression come with trade-offs. Query-aware compression gives the highest performance, but we have to rerun the compression for each new query; question-agnostic presses can instead compress independently of the question, giving more throughput. This is a topic we explored in our recent paper: https://arxiv.org/pdf/2503.04973, where we propose a middle ground and compress using just the task and a few shot examples. That being said, the final budget of tokens should be equal for all approaches for a fair comparison IMO. What do you think?
> but we have to rerun the compression for each new query

Approaches like SnapKV or Finch are close to sparse attention (i.e. retrieving the right KV pairs from the full KV cache) and somewhat different from compression, because, as you mention, you have to re-run them for each query.
> That being said, the final budget of tokens should be equal for all the approaches for a fair comparison IMO

I agree, that's why I proposed this update. Could you review #69 and comment? You can also open a new PR with the same code if you want to appear in the contributors of the repo.
Would it be possible to add Co-authored-by in your commits, adding myself, @miriam-16 and @FaureElia?
Done, could you please close this branch?
Ok, I found the difference: in your implementation the …

PR description
This pull request introduces the FinchPress implementation (#59), incorporating chunked forward propagation, normalization of scores, and key re-rotation in alignment with the authors' original specifications.
The provided implementation has been thoroughly validated through extensive testing against the original reference code from the authors, ensuring a precise 1-to-1 mapping in functionality.
Below is a plot of FinchPress in comparison with SnapKVPress (w/ question) on RULER-4k.

Checklist

- … `QFilterPress`
- `git commit -s`
- `mypress_press.py` is in the `presses` directory
- `MyPress` is in `__init__.py`
- `README.md` is updated with a 1 liner about the new press in the Available presses section
- `default_presses` list in `tests/default_presses.py`